feat(dflash): MoE 35B-A3B support + DDTree CUDA graph reuse #39

Open
dusterbloom wants to merge 11 commits into Luce-Org:main from dusterbloom:feat/moe-35b-a3b

Conversation

@dusterbloom
Contributor

Summary

Adds Qwen3.5/3.6 35B-A3B MoE target support to the dflash spec-decode path, plus performance work on the DDTree verify graph and MoE AR.

MoE 35B-A3B (5 cycles)

  • `feat(dflash): accept qwen35moe arch in GGUF loader (MoE cycle 1)`
  • `feat(dflash): implement MoE FFN with expert routing + shared expert (cycle 2)`
  • `feat(dflash): parameterize target graph for 40-layer MoE + full forward test (cycle 3)`
  • `feat(dflash): parameterize draft model for 35B-A3B MoE + YaRN RoPE (cycle 4)`
  • `refactor(dflash): parameterize test code for multi-model support (cycle 5)`

Perf / fixes

  • `fix(dflash): reshape shared expert tensors for batched MoE FFN`
  • `perf(dflash): optimize MoE compute and graph allocator reuse`
  • `fix(dflash): correct YaRN attention scale for MoE draft model`
  • `perf(dflash): CUDA graph reuse + GPU argmax for MoE AR (64→144 tok/s)`
  • `perf(dflash): CUDA graph reuse for DDTree target verify`

Merge with main

This branch was synced with `origin/main` (32 upstream commits, including layer-segmented prefill + sliding-window FA, Blackwell/NVFP4 megakernel, and `--fa-window` CLI). Conflicts in `internal.h`, `qwen35_target_graph.cpp`, and `test_dflash.cpp` were resolved by:

  • keeping upstream's prefill_only / windowed FA cache + `build_qwen35_layer` helper
  • preserving the user-side MoE FFN entry point and reusable DDTree graph path
  • making the DDTree ancestor mask window-aware (uses `tree_win_start` consistent with `g_fa_window`)

Verified end-to-end on Qwen3.6-27B-Q4_K_XL + dflash-3.6 drafter:

  • `./build/test_dflash` builds clean, smoke binaries link
  • merge_sort prompt at temp=0 reproduces identical token output and per-step accept stats vs the pre-merge tip (400/912 = 43.9% accept, 7.02 tokens committed/step)
  • tps within measurement noise of pre-merge baseline (~85-90 tps on RTX 3090, ddtree budget=22)

Side note: comparison with the buun-llama-cpp fork

While testing this branch we ran the same prompt on https://github.com/spiritbuun/buun-llama-cpp (`Qwen3.6-27B-DFlash-GGUF` linear-chain spec-decode in upstream llama.cpp) for cross-implementation calibration:

| Stack | Drafter | tps (avg, code prompt @ temp=0) | Accept |
|---|---|---|---|
| buun-llama-cpp linear chain | spiritbuun Q4_K_M | 189 | 86.1% |
| buun-llama-cpp linear chain | z-lab F16 (lucebox's drafter, converted) | 125 | 78.5% |
| buun-llama-cpp linear chain | z-lab Q4_K_M | 149 | 71% |
| lucebox ddtree budget=16 | z-lab F16 | 109 | 48.1% |
| lucebox ddtree budget=22 (default) | z-lab F16 | 97 | 45.5% |

Three takeaways relevant to lucebox:

  1. Drafter weights are a real lever. Holding the runtime constant (buun chain), spiritbuun's drafter delivers +7.6pt accept and ~+50% tps over z-lab's at F16. spiritbuun appears to have re-trained / fine-tuned on top of the z-lab release rather than just quantising it.

  2. DDTree budget=16 is faster than the default budget=22 on this prompt (109 vs 97 tps) — fewer redundant tree branches, slightly higher per-step accept (48.1% vs 45.5%). Worth considering as the default for short-context code-shaped prompts. Budgets below block_size (16) crash with a ggml shape assertion in test_dflash.

  3. We attempted to add a linear-chain mode in lucebox via `--fast-rollback` (no `--ddtree`) but it consistently produced ~42% accept on the same drafter — substantially worse than buun's chain at 78.5% with the same weights. We've left that investigation on a separate branch (`session-debug-2026-04-26` on the fork) along with a new `test_chunked_vs_seq.cpp` regression that exercises `build_delta_net_chunked` against `ggml_gated_delta_net` and a scalar-C++ reference. The test currently fails at n_tokens=16 for all three paths against each other — so the disagreement is not uniquely a lucebox bug, but the test is a useful starting point for future GDA correctness work.

Test plan

  • `cmake --build build --target test_dflash` clean (with the merge fixes)
  • `scripts/server.py` boots, daemon ready, OpenAI endpoint responds
  • merge_sort code prompt produces sensible Python at temp=0
  • Per-step accept + commit-rate match pre-merge tip (no algorithmic regression)
  • MoE 35B-A3B end-to-end (covered by the user's existing local validation; no MoE model on the merge tester's box)

feat(dflash): accept qwen35moe arch in GGUF loader (MoE cycle 1)

Add MoE tensor fields to TargetLayer (ffn_gate_inp, ffn_up_exps,
ffn_gate_exps, ffn_down_exps, shared expert tensors) and MoE hparams
to TargetWeights (n_expert, n_expert_used, expert_ff_dim, shared_ff_dim).

Update load_target_gguf() to accept both qwen35 (dense) and qwen35moe
architectures with separate validation paths. Add smoke_load_moe_target
test that loads Qwen3.6-35B-A3B and validates all 40 layers, 256 experts,
10 full-attn + 30 delta-net layers.

No regression: 27B loader still passes smoke_load_target.

feat(dflash): implement MoE FFN with expert routing + shared expert (cycle 2)

Add build_moe_ffn() implementing full qwen35moe FFN path:
- Softmax gating over 256 experts, top-8 selection
- Per-expert SwiGLU via ggml_mul_mat_id
- Weight normalization and aggregation
- Shared expert path with sigmoid gating (ffn_gate_inp_shexp)

Tested with smoke_moe_ffn on Qwen3.6-35B-A3B: valid output,
no NaN/Inf, correct shape [2048, 1].

feat(dflash): parameterize target graph for 40-layer MoE + full forward test (cycle 3)

Replace all q35:: namespace constants with runtime reads from TargetWeights
so the same graph builder handles both 64-layer 27B and 40-layer 35B-A3B MoE.
Dynamic CAPTURE_LAYERS computation, MoE FFN branch, and dynamic cache sizing.
Full forward smoke test passes for both models with no regressions.

feat(dflash): parameterize draft model for 35B-A3B MoE + YaRN RoPE (cycle 4)

Add DraftHparams struct with config.json parsing for layer count, hidden size,
attention dims, and YaRN RoPE scaling params. Parameterize draft loader and
graph builder to handle both 5-layer 27B and 8-layer 35B-A3B drafts.
YaRN RoPE with factor=64, beta_fast=32, beta_slow=1 supported.
Both draft models pass forward smoke tests with no regressions.

refactor(dflash): parameterize test code for multi-model support (cycle 5)

Replace all DFLASH27B_TARGET_HIDDEN/VOCAB/DRAFT_BLOCK_SIZE/N_TARGET_LAYERS
macro usages in test_dflash.cpp and smoke_draft_graph.cpp with runtime reads
from loaded model weights. Enables the speculative decoding loop to run
with both 64-layer 27B and 40-layer 35B-A3B MoE models.

fix(dflash): reshape shared expert tensors for batched MoE FFN

Reshape sh_gate/sh_up to 2D and sh_down to 2D before shared expert
gating broadcast, fixing ggml_can_repeat assertion when n_tokens > 1.
Chain speculative decoding: 78 tok/s, DDTree: 14 tok/s on RTX 3090.

perf(dflash): optimize MoE compute and graph allocator reuse

Remove unnecessary ggml_repeat in shared expert gating (use ggml_mul
broadcast instead). Add ggml_gallocr_reserve for graph buffer reuse
and parameterize test_generate for both model sizes.

Benchmarks on RTX 3090 (target-only decode):
  27B Q4_K:  35.3 tok/s (llama.cpp: 36.5, gap: -3.3%)
  35B-A3B:   64.5 tok/s (llama.cpp: 85.0, gap: -24.1%)

fix(dflash): correct YaRN attention scale for MoE draft model

The MoE draft model (factor=64 YaRN) was using attn_factor=1/(64^2)=1/4096
as the flash attention scale, making attention 4096x too weak. The Python
reference uses standard 1/sqrt(head_dim) — YaRN correction belongs only in
the RoPE cos/sin multipliers (ggml_rope_ext mscale param), not the attention
scale. Also fixed the RoPE mscale from 1/factor^2 to the correct YaRN
formula: 1/(0.1*ln(factor)+1) = 0.706 for factor=64.

HumanEval DDTree benchmark (RTX 3090, budget=22):
  MoE 35B-A3B: 19.1 -> 53.0 tok/s (2.8x improvement)
  27B:         81.2 tok/s (no change, factor=1 unaffected)

perf(dflash): CUDA graph reuse + GPU argmax for MoE AR (64→144 tok/s)

Enable CUDA graphs and rewrite test_generate with fixed-graph architecture:
- K/V written to fixed scratch slot (max_ctx-1), copied to correct position
  after compute so graph structure never changes between decode steps
- F16 attention mask input for variable-length causal attention
- ggml_argmax in graph eliminates GPU→CPU logits transfer per step
- CUDA graph replay eliminates ~1000 kernel launches per decode step

Results on RTX 3090 (Qwen3.6-35B-A3B Q2_K):
  MoE AR: 64.5 → 143.8 tok/s (+123%, now 1.7x faster than llama.cpp)
  27B AR: 35.3 → 41.9 tok/s (+19%)
  27B DDTree: 85.4 → 83.3 tok/s (no regression)

perf(dflash): CUDA graph reuse for DDTree target verify

Add build_target_step_tree_reusable() with fixed kv_start and n_tokens
so CUDA graphs can replay across DDTree decode steps. K/V and target
features are written to scratch slots (max_ctx - budget - 1 .. max_ctx - 2)
and copied to committed positions after verify.

Results on RTX 3090:
  MoE DDTree: 53.8 → 55.8 tok/s (+3.7%, limited by 22% acceptance)
  27B DDTree: 83.3 → 80.8 tok/s (no regression, within noise)
# Conflicts:
#	dflash/deps/llama.cpp
#	dflash/src/internal.h
#	dflash/src/qwen35_target_graph.cpp
#	dflash/test/test_dflash.cpp
@davide221
Contributor

@dusterbloom can you rebase? It would be cool to have the Qwen MoE available for use cases different from Qwen 27B's.

javierpazo added a commit to javierpazo/lucebox-hub that referenced this pull request May 10, 2026
This change brings concurrent multi-request execution to test_dflash
on a single GPU. It is internally one cohesive unit but can be split
into four conceptual pieces if a smaller review is preferred:

1. Multi TargetCache slots
   - CLI: --target-cache-slots=N (alias --cache-slots=N)
   - prefix `SLOT <id>` routes commands to a specific slot
   - DaemonSlotState + RAII ActiveDaemonSlot for safe switching
   - LIST_TARGET_CACHE_SLOTS for introspection
   - all slots share target/draft weights; only KV/SSM/scratch is
     per-slot
   - create_target_cache gains an `n_seqs` parameter so a single
     cache can be allocated batched up front

2. Tagged stream protocol (opt-in)
   - --stream-tagged emits frames `[-2, request_id, token]` instead
     of bare int32 tokens; sentinels `-4` (CONTINUE), `-1` (DONE)
   - parser recognises `REQ <id>` / `REQUEST <id>` headers
   - legacy bare-int32 streaming is unchanged when the flag is off
   - this lets a client demux multiple concurrent requests over the
     same stdout

3. Native quantum scheduler
   - dispatch table for REQ/SLOT/START, SCHED_STEP, SCHED_DRAIN,
     LIST_REQUESTS
   - cursor-based fair round-robin between admitted requests
   - non-blocking reader thread admits new requests during a drain
   - PendingQuantum{slot, req, epoch, n_gen} carries the unit of work
   - CONTINUE / CONT resumes a slot without re-prefilling
   - REQ <id> CANCEL invalidates a request and bumps the slot epoch
     so a stale CONTINUE is rejected; RESTORE_CHAIN / legacy generate
     refuse to overwrite a slot that is owned by an active scheduler
     request

4. Fused batched target step (CUDA path)
   - new commands: SCHED_BATCH_PEEK, SCHED_BATCH_PROBE,
     SCHED_BATCH_TARGET_TAIL, SCHED_BATCH_TARGET_STEP,
     SCHED_BATCH_DRAIN
   - QwenGraphInputs gains `n_seqs`; build_delta_net_block accepts
     n_seqs > 1
   - target_feat is allocated as [5*hidden, target_feat_cap, n_seqs]
     when batched and the chain forwards capture features per-seq
   - batch_probe_compare_ok smoke shows mismatches=0 vs the
     single-seq path; SCHED_BATCH_TARGET_TAIL commits two completed
     pending quanta in 29.26 ms; SCHED_BATCH_TARGET_STEP commits the
     next batched step in 29.57 ms; SCHED_BATCH_DRAIN completes
     req12/req13 with two batched steps each
   - rollback for partially accepted draft tokens, multi-token verify
     and parent-id propagation in the batched path are noted as
     follow-ups; today the batched step accepts the cleanest case
     and falls back to single-seq when needed

Validation (single GPU: RTX 6000 Ada, sm_89; Heretic Q4_K_M target +
Q8 GGUF or FP16 safetensors drafter; FA_WINDOW=0; KV q4_0/q4_0):

- Two concurrent requests:
    REQ 4 START SLOT 0 quantum=2
    REQ 5 START SLOT 1 quantum=2
    SCHED_DRAIN closes both clean.
    slot 0: 18.41 tok/s, slot 1: 22.50 tok/s
- Mid-drain admission of REQ 6 succeeds; CONTINUE on slot 0 resumes
  without re-prefill.
- batch_probe_compare_ok mismatches=0 over a 2-seq probe.
- batch_tail_commit count=2 ms=29.26.
- batch_step_commit ms=29.57 followed by SCHED_DRAIN reverts cleanly
  back to the DFlash single-seq path.

Compatibility:
- All new behaviour is opt-in. Default invocation of test_dflash
  with no scheduler flags keeps the legacy single-request path.
- Tagged stream is gated behind --stream-tagged.
- Multi-slot is gated behind --target-cache-slots=N (default N=1).
- Batched target step is reached only via the SCHED_BATCH_* command
  family; legacy SCHED_STEP keeps using the single-seq path.
- Hot-loop diagnostic logs (sync_us / step_debug) are now gated
  behind DFLASH27B_TIMING_DEBUG / DFLASH27B_STEP_DEBUG so the
  default path is unchanged.

Verification vs existing community PRs:
- No prior art in lucebox-hub for the SCHED_BATCH_* protocol or for
  a native C++ quantum scheduler with REQ/SLOT/CONTINUE/CANCEL +
  epoch hardening. Checked against PR Luce-Org#39 (CUDA graph reuse) and
  PR Luce-Org#62 (split target/draft StepGraphs); both reuse / split graphs
  but neither exposes a multi-request slot protocol.
- No upstream collision found for tagged stream framing or
  --target-cache-slots.

Happy to split this into four sequential PRs (slots / tagged stream /
quantum scheduler / batched target step) if a smaller-grained review
is preferred — let me know.

Author: Javier Pazo <xabicasa@gmail.com>